Data Preprocessing, Model Loading, Prediction, Evaluation¶

This notebook shows how to preprocess audio files, load a trained model, predict pitches, and evaluate the estimates.

In [1]:
import os
import sys
# basepath = parent of the current working directory (os.path.abspath('.') is
# the notebook's working dir); appended to sys.path so the repo-local packages
# (libdl) become importable from this notebook folder.
basepath = os.path.dirname(os.path.abspath('.'))
sys.path.append(basepath)

import numpy as np
import pandas as pd
import librosa
import libfmp
import matplotlib.pyplot as plt
import IPython.display as ipd
import torch
import torchinfo

# Repo-local deep-learning utilities (preprocessing, data loaders, models, metrics)
import libdl
In [2]:
# CPU / GPU 
# Select the compute device; uncomment the cuda line to run on GPU instead.
device = torch.device('cpu')
# device = torch.device('cuda')

1. Load and preprocess audio¶

Load audio file¶

In [3]:
# Target sampling rate (Hz); librosa resamples the file to this rate on load
fs = 22050

audio_folder = os.path.join(basepath, 'data', 'Schubert_Winterreise', '01_RawData', 'audio_wav')
fn_audio = 'Schubert_D911-23_SC06.wav'

# Load audio
# fs_load is the sampling rate actually used (equals fs, since sr=fs is passed)
path_audio = os.path.join(audio_folder, fn_audio)
f_audio, fs_load = librosa.load(path_audio, sr=fs)
In [4]:
# Plot the waveform and embed an audio player for listening
libfmp.b.plot_signal(f_audio, Fs=fs_load)
ipd.display(ipd.Audio(data=f_audio, rate=fs_load))
Your browser does not support the audio element.

Compute HCQT¶

In [5]:
# HCQT parameters
bins_per_semitone = 3
hcqt_config = {
    'fs': fs,                                   # audio sampling rate (Hz)
    'fmin': librosa.note_to_hz('C1'),  # MIDI pitch 24
    'fs_hcqt_target': 50,                       # target HCQT frame rate (Hz)
    'bins_per_octave': 12 * bins_per_semitone,  # 36 bins per octave
    'num_octaves': 6,
    'num_harmonics': 5,
    'num_subharmonics': 1,
    'center_bins': True,
}

# Compute HCQT (harmonic constant-Q transform); returns the HCQT array, its
# frame rate, and the CQT hop size.
# Note: removed a stray trailing ';' — an assignment produces no cell output
# to suppress, so the semicolon was dead syntax.
f_hcqt, fs_hcqt, hopsize_cqt = libdl.data_preprocessing.compute_efficient_hcqt(f_audio, **hcqt_config)

Visualize first harmonic¶

In [6]:
def plot_matrix_with_ticks(data, title, bins_per_semitone=bins_per_semitone, 
                           hcqt_config=hcqt_config, fs_hqct=fs_hcqt, pitches=True, **kwargs):
    """Plot a pitch-time matrix excerpt with MIDI-pitch y-ticks and second x-ticks.

    Args:
        data: 2D array (pitch bins x time frames) to visualize.
        title: Title string for the plot.
        bins_per_semitone: CQT resolution (defaults to the notebook-level value).
        hcqt_config: HCQT configuration dict; only 'num_octaves' is read.
        fs_hqct: Frame rate of `data` in Hz. (Parameter name keeps the original
            'hqct' typo so existing keyword callers stay compatible.)
        pitches: If True, assume one bin per semitone (model output / annotations);
            if False, assume `bins_per_semitone` bins per semitone (raw CQT).
        **kwargs: Forwarded to libfmp.b.plot_matrix (e.g. clim).
    """
    vis_start_sec = 25
    vis_stop_sec = 50
    vis_step_sec = 5

    n_bins = bins_per_semitone*12*hcqt_config["num_octaves"]

    plt.rcParams.update({'font.size': 11})
    fig, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 0.05]}, figsize=(10, 3.5))
    # BUGFIX: use the fs_hqct parameter here. The original body referenced the
    # module-level fs_hcqt, so any frame rate passed in was silently ignored.
    im = libfmp.b.plot_matrix(data[:, int(vis_start_sec*fs_hqct):int(vis_stop_sec*fs_hqct)],
                              Fs=fs_hqct, ax=ax, cmap='gray_r', ylabel='MIDI pitch', **kwargs)

    if pitches:
        # One tick per octave on a semitone-resolution axis; labels 24, 36, ..., 96
        ax[0].set_yticks(np.arange(0, 73, 12))
        ax[0].set_yticklabels([str(24+12*octave) for octave in range(0, hcqt_config["num_octaves"]+1)])
    else:
        # One tick per octave at CQT resolution (bins_per_semitone bins per semitone)
        ax[0].set_yticks(np.arange(1, n_bins+13, 12*bins_per_semitone))
        ax[0].set_yticklabels([str(24+12*octave) for octave in range(0, hcqt_config["num_octaves"]+1)])
    # The plotted excerpt starts at vis_start_sec, so relabel the x ticks accordingly
    ax[0].set_xticks(np.arange(0, (vis_stop_sec-vis_start_sec)+vis_step_sec, vis_step_sec))
    ax[0].set_xticklabels(np.arange(vis_start_sec, vis_stop_sec+vis_step_sec, vis_step_sec))
    ax[0].set_title(title)
    plt.tight_layout()
In [7]:
plot_matrix_with_ticks(data=np.log(1+10*np.abs(f_hcqt[:, :, 1])), title='Harmonic 1 (fundamental)', pitches=False)

2. Specify and load model¶

In [8]:
# Directory containing the pretrained model checkpoints
dir_models = os.path.join(basepath, 'experiments', 'models')

# Choose one of the trained model variants (uncomment exactly one):
# fn_model = '02_schubert_baseline_ae.pt'
# fn_model = '03_schubert_baseline_sup.pt'  
# fn_model = '04_schubert_cva.pt'
# fn_model = '05_schubert_cva_ov.pt'
# fn_model = '06_schubert_cva_b.pt'
fn_model = '07_schubert_cva_ov_b.pt'
In [9]:
# Model parameters
num_octaves_inp = 6  # must match hcqt_config['num_octaves']
num_output_bins, min_pitch = 72, 24  # 72 semitone bins starting at MIDI pitch 24
model_params = {
    'n_chan_input': 6,  # input channels (presumably 1 sub- + 5 harmonics, per hcqt_config)
    'n_chan_layers': [20, 20, 10, 1],
    'n_bins_in': num_octaves_inp * 12 * 3,  # 216 CQT bins (3 bins per semitone)
    'n_bins_out': num_output_bins,
    'a_lrelu': 0.3,  # LeakyReLU negative slope
    'p_dropout': 0.2
}

if fn_model == '03_schubert_baseline_sup.pt':
    # Model without final sigmoid activation; only for 03_schubert_baseline_sup 
    model = libdl.nn_models.basic_cnn_segm_logit(**model_params)
else:
    # Model with final sigmoid activation
    model = libdl.nn_models.basic_cnn_segm_sigmoid(**model_params)
In [10]:
# Load trained model
# map_location=device lets GPU-saved checkpoints load on CPU-only machines.
model.load_state_dict(torch.load(os.path.join(dir_models, fn_model), map_location=device))

model.to(device)
# eval(): disable dropout for deterministic inference; trailing ';' hides the repr
model.eval();
In [11]:
torchinfo.summary(model, input_size=(1, 6, 574, 216), device=device)
Out[11]:
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
basic_cnn_segm_sigmoid                   [1, 1, 500, 72]           --
├─LayerNorm: 1-1                         [1, 574, 6, 216]          2,592
├─Sequential: 1-2                        [1, 20, 574, 216]         --
│    └─Conv2d: 2-1                       [1, 20, 574, 216]         27,020
│    └─LeakyReLU: 2-2                    [1, 20, 574, 216]         --
│    └─MaxPool2d: 2-3                    [1, 20, 574, 216]         --
│    └─Dropout: 2-4                      [1, 20, 574, 216]         --
├─Sequential: 1-3                        [1, 20, 574, 72]          --
│    └─Conv2d: 2-5                       [1, 20, 574, 72]          3,620
│    └─LeakyReLU: 2-6                    [1, 20, 574, 72]          --
│    └─MaxPool2d: 2-7                    [1, 20, 574, 72]          --
│    └─Dropout: 2-8                      [1, 20, 574, 72]          --
├─Sequential: 1-4                        [1, 10, 500, 72]          --
│    └─Conv2d: 2-9                       [1, 10, 500, 72]          15,010
│    └─LeakyReLU: 2-10                   [1, 10, 500, 72]          --
│    └─Dropout: 2-11                     [1, 10, 500, 72]          --
├─Sequential: 1-5                        [1, 1, 500, 72]           --
│    └─Conv2d: 2-12                      [1, 1, 500, 72]           11
│    └─LeakyReLU: 2-13                   [1, 1, 500, 72]           --
│    └─Dropout: 2-14                     [1, 1, 500, 72]           --
│    └─Conv2d: 2-15                      [1, 1, 500, 72]           2
│    └─Sigmoid: 2-16                     [1, 1, 500, 72]           --
==========================================================================================
Total params: 48,255
Trainable params: 48,255
Non-trainable params: 0
Total mult-adds (G): 4.04
==========================================================================================
Input size (MB): 2.98
Forward/backward pass size (MB): 35.86
Params size (MB): 0.19
Estimated Total Size (MB): 39.03
==========================================================================================

3. Predict pitches¶

Create dataset object¶

In [12]:
test_dataset_params = {
    'context': 75,      # number of context frames around each target frame
    'compression': 10   # log-compression applied to HCQT
}

half_context = test_dataset_params['context'] // 2

# Reorder HCQT from (bins, frames, channels) to (channels, frames, bins)
inputs = np.transpose(f_hcqt, (2, 1, 0))

# Pad input in order to account for context frames
inputs_context = torch.from_numpy(np.pad(inputs, ((0, 0), (half_context, half_context+1), (0, 0))))

# Create dummy targets for dataset object
targets_context = torch.zeros(inputs_context.shape[1], num_output_bins)

test_dataset_params['seglength'] = inputs.shape[1]  # dataset will then contain only 1 segment which includes all frames
test_dataset_params['stride'] = inputs.shape[1]

test_set = libdl.data_loaders.dataset_context_segm(inputs_context, targets_context, test_dataset_params)

Make prediction¶

In [13]:
test_batch, _ = test_set[0]

# Batch format: add a leading batch dimension and move to the chosen device
test_batch = test_batch.unsqueeze(dim=0).to(device)

# Predict without tracking gradients — pure inference, so building the
# autograd graph would only waste memory and compute.
with torch.no_grad():
    y_pred = model(test_batch)

# Apply sigmoid activation if not contained as last layer in model
# (isinstance instead of exact __class__ comparison: also covers subclasses)
if isinstance(model, libdl.nn_models.basic_cnns_mctc.basic_cnn_segm_logit):
    y_pred = torch.sigmoid(y_pred)

# Convert prediction to Numpy array (drop batch/channel singleton dims)
pred = y_pred.to('cpu').detach().squeeze().numpy()
In [14]:
plot_matrix_with_ticks(data=pred.T, title='Pitch prediction', pitches=True, clim=[0.0, 1.0])

(Visualize predictions + overtone model / bias)¶

In [15]:
def overtone_model(pred):
    """Superimpose attenuated overtone energy onto a pitch activation tensor.

    For each overtone interval (in semitones above the fundamental), the
    activations are shifted up by that interval, scaled by 0.9**interval,
    and accumulated; the result is clipped back into [0, 1].

    Args:
        pred: Tensor of shape (batch, frames, pitch_bins) with values in [0, 1].

    Returns:
        Tensor of the same shape with overtone energy added, clipped to [0, 1].
    """
    # Semitone offsets of overtones 2..10 relative to the fundamental
    overtone_shifts = np.array([12, 19, 24, 28, 31, 34, 36, 38, 40])
    overtone_weights = 0.9 ** overtone_shifts

    result = pred.clone()
    for offset, weight in zip(overtone_shifts.tolist(), overtone_weights.tolist()):
        result[:, :, offset:] = result[:, :, offset:] + weight * pred[:, :, :-offset]
    return result.clamp(0.0, 1.0)

# Apply the overtone model to the raw prediction (drop the channel dim first)
pred_ov = overtone_model(y_pred.squeeze(dim=1))
pred_ov_np = pred_ov.to('cpu').detach().squeeze().numpy()

plot_matrix_with_ticks(data=pred_ov_np.T, title='Pitch prediction + Ov', pitches=True, clim=[0.0, 1.0])
In [16]:
# Add a constant bias to all activations and clip back into [0, 1]
bias = 0.2
pred_ov_b = torch.clip(pred_ov + bias, 0.0, 1.0).to('cpu').detach().squeeze().numpy()

plot_matrix_with_ticks(data=pred_ov_b.T, title='Pitch prediction + Ov + B', pitches=True, clim=[0.0, 1.0])

4. Load and convert annotations¶

In [17]:
annot_folder = os.path.join(basepath, 'data', 'Schubert_Winterreise', '02_Annotations', 'ann_audio_note')
# Derive the annotation filename from the audio filename.
# os.path.splitext replaces the fragile fn_audio[:-4] slice, which silently
# breaks for extensions that are not exactly four characters long.
fn_annot = os.path.join(annot_folder, os.path.splitext(fn_audio)[0] + '.csv')

if os.path.exists(fn_annot):
    # Read note events from the semicolon-separated CSV (skip the header row);
    # first three columns are used (presumably start, end, pitch — confirm
    # against the annotation file format).
    df = pd.read_csv(fn_annot, sep=';', skiprows=1, header=None)
    note_events = df.to_numpy()[:, :3]

    f_annot_pitch = libdl.data_preprocessing.compute_annotation_array_nooverlap(note_events, f_hcqt, fs_hcqt, 
                                                                               annot_type='pitch', shorten=1.0)
In [18]:
if os.path.exists(fn_annot):
    # Show only rows 24..96 of the annotation array — the MIDI pitch range
    # displayed in the other plots
    plot_matrix_with_ticks(data=f_annot_pitch[24:97], title='Pitch annotations', pitches=True)

5. Multi-pitch evaluation¶

In [19]:
# Frame-wise evaluation measures to compute
eval_measures = ['precision', 'recall', 'f_measure', 'cosine_sim', 'binary_crossentropy', 'euclidean_distance',
                 'binary_accuracy', 'soft_accuracy', 'accum_energy', 'roc_auc_measure', 'average_precision_score']

# Detection threshold tau applied to the activations
eval_thresh = 0.4
In [20]:
# Thresholding
# Binarize the prediction: 1.0 where the activation exceeds the threshold
pred_th = (pred > eval_thresh).astype(float)

plot_matrix_with_ticks(data=pred_th.T, title=f'Pitch prediction after thresholding (tau={eval_thresh})', pitches=True)
In [21]:
if os.path.exists(fn_annot):
    # Calculate metrics
    # Transpose the annotation array to (frames x pitches) and crop it to the
    # model's output range: num_output_bins bins starting at min_pitch.
    targ = np.transpose(f_annot_pitch, (1, 0))[:, min_pitch:(min_pitch+num_output_bins)]

    eval_dict = libdl.metrics.calculate_eval_measures(targ, pred, measures=eval_measures, threshold=eval_thresh, save_roc_plot=False)
    eval_numbers = np.fromiter(eval_dict.values(), dtype=float)

    metrics_mpe = libdl.metrics.calculate_mpe_measures_mireval(targ, pred, threshold=eval_thresh, min_pitch=min_pitch)
    # list() over the dict directly — the comprehension over .keys() was redundant
    mireval_measures = list(metrics_mpe)
    mireval_numbers = np.fromiter(metrics_mpe.values(), dtype=float)
In [22]:
if os.path.exists(fn_annot):
    # Print the frame-wise measures, then a blank line, then the mir_eval measures,
    # each left-padded to a 30-character name column
    for idx in range(len(eval_measures)):
        print(f'{eval_measures[idx]:<30} {eval_numbers[idx]}')

    print('')

    for idx in range(len(mireval_measures)):
        print(f'{mireval_measures[idx]:<30} {mireval_numbers[idx]}')
precision                      0.48616802191036285
recall                         0.8176462186674284
f_measure                      0.6097700613790992
cosine_sim                     0.7060966329100041
binary_crossentropy            0.15767848274651708
euclidean_distance             1.4385405760872234
binary_accuracy                0.9470168626259997
soft_accuracy                  0.9417796386451203
accum_energy                   0.6761354310856093
roc_auc_measure                0.9687468215659629
average_precision_score        0.7337572851609266

Precision                      0.48616802191036285
Recall                         0.8176462186674284
Accuracy                       0.4386109408519767
Substitution Error             0.17157037777619083
Miss Error                     0.010783403556380775
False Alarm Error              0.6926015853745626
Total Error                    0.8749553667071341
Chroma Precision               0.5429184093755971
Chroma Recall                  0.913090052131686
Chroma Accuracy                0.5162410416876956
Chroma Substitution Error      0.07612654431193315
Chroma Miss Error              0.010783403556380775
Chroma False Alarm Error       0.6926015853745626
Chroma Total Error             0.7795115332428765